/*
 *  algorithms.cpp
 *  Classifier
 *
 *  Daniel Wojcik
 *
 */

#include <cmath>
#include <iostream>

#include "algorithms.h"

//Scoring function for a given term
//Was dCount, now requires normalization by
//object-level term counts in caller.
//Score a characteristic term: global scalar times idf times raw count.
//Normalization by object-level term counts is left to the caller
//(the old division by dCount was removed).
double termToDouble(TermShort& old)
{
	const double tf = old.count;
	return scalar * old.idf * tf;
}
//Same scoring for the global (per-corpus) term record.
double termToDouble(TermStat& old)
{
	const double tf = old.count;
	return scalar * old.idf * tf;
}
//Class-specific score: uses the term's count within class cls.
//Note: cCounts[cls] (operator[]) inserts a zero entry when cls is
//absent — same lookup semantics as before.
double termToDouble(TermStat& old, std::string cls)
{
	const double clsCount = old.cCounts[cls];
	return scalar * old.idf * clsCount;
}


//Dispatcher function
//Dispatcher: route the document to the classifier selected by clusMet.
void classify(DocItem& dcmt, Globals& global)
{
	if (clusMet == 0)
		cClassify(dcmt, global);
	else if (clusMet == 1)
		nClassify(dcmt, global);
	else if (clusMet == 2)
		svmClassify(dcmt, global);
	else if (clusMet == 3)
		ejcClassify(dcmt, global);
}

//Dispatcher function
//Dispatcher: run the clustering step matching clusMet.
//Methods 0 and 3 have no clustering step.
void cluster(Globals& global)
{
	if (clusMet == 1)
		nearCluster(global);
	else if (clusMet == 2)
		svmCluster(global);
}

//Record that another document of dcmt's (first) classification has been
//seen, then refresh the global term statistics: idf for every term and
//the per-class occurrence counts for this document's class.
void updateWeights(DocItem* dcmt, Globals& global)
{
	std::string classf = dcmt->classification[0];

	//Single lookup instead of find() followed by repeated operator[]
	//calls on the same key.
	std::map<std::string, ClassStat>::iterator cIt = global.docClasses.find(classf);
	if (cIt == global.docClasses.end())
	{
		//First document of this class: create and initialize the entry.
		ClassStat& stat = global.docClasses[classf];
		stat.seen = true;
		stat.dCount = 1;
	}
	else
		cIt->second.dCount++;

	//Update idf for every known term and fold this document's term
	//occurrences into the per-class counts.
	std::map<std::string, TermStat>::iterator itr = global.gTerms.begin();
	while (itr != global.gTerms.end())
	{
		//NOTE(review): float cast kept to preserve the original precision.
		itr->second.idf = log10((float)global.docCount / itr->second.dCount);
		itr->second.cCounts[classf]+= dcmt->getCount(itr->first);

		itr++;
	}
}

//Rebuild each class's characteristic-term list: for every class, keep
//the topK global terms with the highest score (termToDouble) for that
//class, tracking termSize (how many kept) and termCount (sum of kept
//per-class counts).
void characterizeClasses(Globals& global)
{
	cull(global);
	std::cout<< "Characterizing\n";
	
	std::map<std::string, ClassStat>::iterator itr = global.docClasses.begin();
	while (itr != global.docClasses.end())
	{
		std::string classf = itr->first;
		//Reset the per-class summary before recomputing it from gTerms.
		itr->second.termSize = 0;
		itr->second.termCount = 0;
		itr->second.charTerms.clear();
		
		std::map<std::string, TermStat>::iterator tItr = global.gTerms.begin();
		while (tItr != global.gTerms.end())
		{
			//Copy (not reference): cCounts[classf] below uses operator[],
			//which can insert a zero entry; on a copy that insertion
			//stays local and the global map is untouched.
			TermStat term = tItr->second;
			if (itr->second.termSize < topK)
			{
				//Below capacity: accept the term unconditionally.
				itr->second.charTerms[tItr->first].count = term.cCounts[classf];
				itr->second.charTerms[tItr->first].dCount = term.dCount;
				itr->second.charTerms[tItr->first].idf = term.idf;
				itr->second.termSize++;
				itr->second.termCount+= term.cCounts[classf];
			}
			else
			{
				//At capacity: find the lowest-scoring kept term.
				//min.dCount == 0 acts as the "nothing found yet" sentinel.
				TermShort min;
				min.count = 0;
				min.dCount = 0;
				min.idf = 0.0;
				std::string minStr;
				
				std::map<std::string, TermShort>::iterator cItr = itr->second.charTerms.begin();
				while (cItr != itr->second.charTerms.end())
				{
					double minScore = termToDouble(min);
					double curScore = termToDouble(cItr->second);
					if (min.dCount == 0 || curScore < minScore)
					{
						min.count = cItr->second.count;
						min.idf = cItr->second.idf;
						min.dCount = cItr->second.dCount;
						minStr = cItr->first;
					}
					cItr++;
				}
				
				//Replace the minimum if the new term scores higher.
				//NOTE(review): on replacement, termCount is not adjusted
				//(min's count is not subtracted, the new count not added),
				//so termCount reflects only the first topK accepted terms —
				//confirm whether that is intended.
				double minScore = termToDouble(min);
				double curScore = termToDouble(term,classf);
				if (minScore < curScore)
				{
					itr->second.charTerms.erase(minStr);
					itr->second.charTerms[tItr->first].count = term.cCounts[classf];
					itr->second.charTerms[tItr->first].dCount = term.dCount;
					itr->second.charTerms[tItr->first].idf = term.idf;
				}
			}
			tItr++;
		}
		itr++;
	}
}

//Generate characteristic terms of the document.
//Could probably do this better given the difference
//between value types in the desired maps.
void characterizeDocument(DocItem& dcmt, Globals& global)
{
	//Keep the topK highest-scoring (termToDouble) global terms for this
	//document in dcmt.charTerms, using the document's own counts and the
	//global idf values.
	dcmt.charTerms.clear();
	unsigned int termSize = 0;
	
	std::map<std::string, TermStat>::iterator tItr = global.gTerms.begin();
	while (tItr != global.gTerms.end())
	{
		//Per-document view of the global term: document count, dCount
		//fixed at 1, global idf.
		TermShort term;
		term.count = dcmt.getCount(tItr->first);
		term.dCount = 1;
		term.idf = tItr->second.idf;
		
		if (termSize < topK)
		{
			//Below capacity: accept unconditionally.
			dcmt.charTerms[tItr->first].count = term.count;
			dcmt.charTerms[tItr->first].dCount = term.dCount;
			dcmt.charTerms[tItr->first].idf = term.idf;
			termSize++;
		}
		else
		{
			//At capacity: locate the currently lowest-scoring kept term.
			//min.dCount == 0 acts as the "nothing found yet" sentinel
			//(kept terms always have dCount == 1 here).
			TermShort min;
			min.count = 0;
			min.dCount = 0;
			min.idf = 0.0;
			std::string minStr;
			
			std::map<std::string, TermShort>::iterator cItr = dcmt.charTerms.begin();
			while (cItr != dcmt.charTerms.end())
			{
				double minScore = termToDouble(min);
				double curScore = termToDouble(cItr->second);
				if (min.dCount == 0 || curScore < minScore)
				{
					min.count = cItr->second.count;
					min.idf = cItr->second.idf;
					min.dCount = cItr->second.dCount;
					minStr = cItr->first;
				}
				cItr++;
			}
			
			//Replace the minimum if the new term scores higher.
			double minScore = termToDouble(min);
			double curScore = termToDouble(term);
			if (minScore < curScore)
			{
				dcmt.charTerms.erase(minStr);
				dcmt.charTerms[tItr->first].count = term.count;
				dcmt.charTerms[tItr->first].dCount = term.dCount;
				dcmt.charTerms[tItr->first].idf = term.idf;
			}
		}
		tItr++;
	}
}

//If too far away, make new class?
//Doesn't do this currently, but should if clustering is enabled.
//Centroid-style classification: score the document against every class's
//characteristic-term vector and assign the class with the maximum weight.
void cClassify(DocItem& dcmt, Globals& global)
{
	std::map<std::string, ClassStat>::iterator itr = global.docClasses.begin();
	std::string classf = "-1";
	double maxW = -1;
	
	while (itr != global.docClasses.end())
	{
		double w = 0;
		std::map<std::string, TermShort>::iterator cItr = itr->second.charTerms.begin();
		while (cItr != itr->second.charTerms.end())
		{
			//Weight function
			//Currently set to use normalization by term, which doesn't work for this method!
			double cw = cItr->second.count / (double)itr->second.termCount;
			double dw = dcmt.getCount(cItr->first) / (double)dcmt.termCount;
			//std::fabs (bug fix): the unqualified abs could bind to the
			//integer overload and truncate the fractional difference to 0.
			w+= scalar * cItr->second.idf * std::fabs(cw - dw);
			cItr++;
		}
		//First class always wins; afterwards keep the maximum weight.
		if (classf.compare("-1") == 0 || maxW < w)
		{
			classf = itr->first;
			maxW = w;
		}
		itr++;
	}

	dcmt.classification[0] = classf;
	
	std::cout << "Document : " << dcmt.classification[0] << " " << maxW << "\n";
	std::cout << "Real Class : " << dcmt.realClass[0] << "\n";
}

//Nearest-class classification: pick the best class, then draw up to
//classTypes-1 secondary candidates from the winning class's cluster.
void nClassify(DocItem& dcmt, Globals& global)
{
	std::map<std::string, ClassStat>::iterator itr = global.docClasses.begin();
	std::string classf = "-1";
	double maxW = -1;
	
	//Score every class; keep the best one.
	while (itr != global.docClasses.end())
	{
		double w = 0;
		std::map<std::string, TermShort>::iterator cItr = itr->second.charTerms.begin();
		while (cItr != itr->second.charTerms.end())
		{
			//Weight function
			//Currently uses normalization by term, which doesn't work for this method!
			double cw = cItr->second.count / (double)itr->second.termCount;
			double dw = dcmt.getCount(cItr->first) / (double)dcmt.termCount;
			//std::fabs (bug fix): unqualified abs could bind to the int
			//overload and truncate fractional differences to 0.
			w+= scalar * cItr->second.idf * std::fabs(cw - dw);
			cItr++;
		}
		if (classf.compare("-1") == 0 || maxW < w)
		{
			classf = itr->first;
			maxW = w;
		}
		itr++;
	}

	dcmt.classification[0] = classf;
	std::pair<std::string,double> classes[classTypes];
	classes[0].first = classf;
	classes[0].second = maxW;
	unsigned int t = 1;
	
	//Next closest classes come from this cluster
	Cluster cls = global.clusters[global.docClasses[classf].cluster];
	std::list<std::string>::iterator dItr = cls.classes.begin();
	while (dItr != cls.classes.end())
	{
		std::string c = *dItr;
		//Advance before any 'continue' (bug fix): the original skipped the
		//increment when c == classf and looped forever on that element.
		dItr++;
		if (c.compare(classf) == 0)
			continue;
		double w = 0;
		std::map<std::string, TermShort>::iterator cItr = global.docClasses[c].charTerms.begin();
		while (cItr != global.docClasses[c].charTerms.end())
		{
			//Weight function
			//Currently uses normalization by term, which doesn't work for this method!
			double cw = cItr->second.count / (double)global.docClasses[c].termCount;
			double dw = dcmt.getCount(cItr->first) / (double)dcmt.termCount;
			w+= scalar * cItr->second.idf * std::fabs(cw - dw);
			cItr++;
		}
		
		//Insertion into the fixed-size, descending result list
		//(slot 0 is reserved for the already-chosen winner).
		unsigned int i = 1;
		while (i < t && i < classTypes)
		{
			if (w > classes[i].second)
			{
				//Swap the candidate with the current occupant, keep sifting.
				std::string temp1 = classes[i].first;
				double temp2 = classes[i].second;
				classes[i].first = c;
				classes[i].second = w;
				c = temp1;
				w = temp2;
			}
			i++;
		}
		if (i < classTypes)
		{
			classes[i].first = c;
			classes[i].second = w;
			t++;
		}
	}
	//Assign classes to document
	for (unsigned int i = 0; i < t; i++)
	{
		dcmt.classification[i] = classes[i].first;
		std::cout<< "Assigned class [" << i << "] : ";
		std::cout<< classes[i].first << " : " << classes[i].second;
		std::cout<< " : " << dcmt.realClass[0] << "\n";
	}
}

//Do clustering on the classes
//Allows it to classify based on computed class data
//Rather than just compute the class from item data
//Also enables program defined classes and superclasses
//Slow, so only done when saving data to file
//Could be sped up by only focusing on characteristic words
//for each document, but finding those could be tricky.
//Cluster the classes via k-nearest-neighbor graph + BFS connected
//components, then merge overlapping clusters and rebuild membership.
void nearCluster(Globals& global)
{
	cull(global);
	std::cout<<"Clustering\n";
	//Phase 1: for every class, find its nearK nearest neighbor classes
	//under an idf-weighted L1 distance over characteristic terms.
	std::map<std::string, ClassStat>::iterator itr = global.docClasses.begin();
	while (itr != global.docClasses.end())
	{
		unsigned int k = 0;
		std::cout<<itr->first <<"\n";
		itr->second.point = 0;
		bool findPoint = false;
		
		std::map<std::string, ClassStat>::iterator cItr = global.docClasses.begin();
		while (cItr != global.docClasses.end())
		{
			if (cItr != itr)
			{
				//L1 distance metric
				std::map<std::string, TermShort>::iterator tItr = itr->second.charTerms.begin();
				double d = 0;
				while (tItr != itr->second.charTerms.end())
				{
					//NOTE(review): findPoint is true on every comparison
					//after the first, so 'point' accumulates the term sum
					//once per remaining class rather than once overall —
					//confirm that this scaling is intended.
					if (findPoint)
						itr->second.point+= tItr->second.count * tItr->second.idf;
						
					std::map<std::string, TermShort>::iterator tItr2 = cItr->second.charTerms.find(tItr->first);
					if (tItr2 != cItr->second.charTerms.end())
					{
						double c1 = tItr->second.count / (double)itr->second.termCount;
						double c2 = tItr2->second.count / (double)cItr->second.termCount;
						//std::fabs (bug fix): unqualified abs could bind to
						//the int overload, truncating these sub-1.0
						//frequency differences to 0.
						d+= std::fabs(c1 - c2) * tItr->second.idf;
					}
					else
					{
						//Some high value to weight against terms uncommon to them.
						d+= penalty;
					}
					tItr++;
				}
				findPoint = true;
				
				if (k < nearK)
				{
					//Still filling the neighbor list; only accept
					//sufficiently close classes.
					if (d < maxD)
					{
						std::pair<std::string, double> p;
						p.first = cItr->first;
						p.second = d;
						itr->second.neighbors[k++] = p;
					}
				}
				else
				{
					//List full: replace the current farthest neighbor
					//if this class is closer.
					unsigned int r = 0;
					double rd = itr->second.neighbors[r].second;
					for (unsigned int i = 1; i < k; i++)
					{
						if (rd < itr->second.neighbors[i].second)
						{
							r = i;
							rd = itr->second.neighbors[i].second;
						}
					}
					if (rd > d)
					{
						itr->second.neighbors[r].second = d;
						itr->second.neighbors[r].first = cItr->first;
					}
				}
			}
			cItr++;
		}
		itr->second.cluster = 0;	//0 = unvisited for the BFS below
		itr++;
	}
	
	//Phase 2: BFS over the neighbor graph to assign cluster ids;
	//contacts between distinct clusters are queued in mergeL for phase 3.
	std::cout<<"Clustering2\n";
	itr = global.docClasses.begin();
	unsigned int cls = 1;
	std::queue<std::string> bfsQ;
	std::list<std::pair<unsigned int, unsigned int> > mergeL;
	
	while (itr != global.docClasses.end())
	{
		//Skip classes already claimed by an earlier BFS.
		if (itr->second.cluster != 0)
		{
			itr++;
			continue;
		}
		
		itr->second.cluster = cls;
		//NOTE(review): neighbors[i] may be default-constructed (empty key)
		//when fewer than nearK neighbors passed the maxD test; the empty
		//string would then be inserted into docClasses here — confirm.
		for (unsigned int i = 0; i < nearK; i++)
		{
			std::string c = itr->second.neighbors[i].first;
			if (global.docClasses[c].cluster != 0)
			{
				//Neighbor already clustered: record that cls and that
				//cluster touch so they can be merged later.
				std::pair<unsigned int, unsigned int> p;
				std::list<std::pair<unsigned int, unsigned int> >::iterator mItr = mergeL.begin();
				while (mItr != mergeL.end())
				{
					if (mItr->first == cls)
					{
						if (mItr->second == global.docClasses[c].cluster)
							break;
						else
						{
							p.first = global.docClasses[c].cluster;
							p.second = mItr->second;
							mergeL.push_back(p);
							break;
						}
					}
					mItr++;
				}
				if (mItr == mergeL.end())
				{
					p.first = cls;
					p.second = global.docClasses[c].cluster;
					mergeL.push_back(p);
				}
			}
			else
				bfsQ.push(c);
		}
		while (!bfsQ.empty())
		{
			std::string c = bfsQ.front();
			bfsQ.pop();
			if (global.docClasses[c].cluster != 0)
			{
				//Same merge bookkeeping as above for already-clustered nodes.
				std::pair<unsigned int, unsigned int> p;
				std::list<std::pair<unsigned int, unsigned int> >::iterator mItr = mergeL.begin();
				while (mItr != mergeL.end())
				{
					if (mItr->first == cls)
					{
						if (mItr->second == global.docClasses[c].cluster)
							break;
						else
						{
							p.first = global.docClasses[c].cluster;
							p.second = mItr->second;
							mergeL.push_back(p);
							break;
						}
					}
					mItr++;
				}
				if (mItr == mergeL.end())
				{
					p.first = cls;
					p.second = global.docClasses[c].cluster;
					mergeL.push_back(p);
				}
				continue;
			}
			
			//Fold this class into the running cluster means.
			double newPoint = global.clusters[cls].meanPoint * global.clusters[cls].count + global.docClasses[c].point;
			double newClass = global.clusters[cls].meanClass * global.clusters[cls].count;// + c;
			global.clusters[cls].count++;
			global.clusters[cls].meanPoint = newPoint / global.clusters[cls].count;
			global.clusters[cls].meanClass = newClass / global.clusters[cls].count;
			global.docClasses[c].cluster = cls;
			for (unsigned int i = 0; i < nearK; i++)
				bfsQ.push(global.docClasses[c].neighbors[i].first);
		}
		
		cls++;
		itr++;
	}
	
	//Phase 3: collapse clusters that were found to touch.
	if (mergeClusters)
	{
		std::cout<<"Clustering3\n";
		while (!mergeL.empty())
		{
			std::pair<unsigned int, unsigned int> p = mergeL.back();
			mergeL.pop_back();
			
			//Count-weighted average of the two clusters' means.
			double newPoint = global.clusters[p.second].meanPoint * global.clusters[p.second].count;
			newPoint+= global.clusters[p.first].meanPoint * global.clusters[p.first].count;
			double newClass = global.clusters[p.second].meanClass * global.clusters[p.second].count;
			newClass+= global.clusters[p.first].meanClass * global.clusters[p.first].count;
			global.clusters[p.second].count+= global.clusters[p.first].count;
			global.clusters[p.second].meanPoint = newPoint / global.clusters[p.second].count;
			global.clusters[p.second].meanClass = newClass / global.clusters[p.second].count;
			
			//Relabel every class that pointed at the absorbed cluster.
			itr = global.docClasses.begin();
			while (itr != global.docClasses.end())
			{
				if (itr->second.cluster == p.first)
					itr->second.cluster = p.second;
				itr++;
			}
			
			global.clusters.erase(p.first);
		}
	}
	
	//Phase 4: rebuild each cluster's class-membership list.
	std::map<unsigned int, Cluster>::iterator clsItr = global.clusters.begin();
	while (clsItr != global.clusters.end())
	{
		clsItr->second.classes.clear();
		clsItr++;
	}
	
	itr = global.docClasses.begin();
	while (itr != global.docClasses.end())
	{
		global.clusters[itr->second.cluster].classes.push_back(itr->first);
		itr++;
	}
		
	//clusterPrint(global);
}

//Do comparison for each class against all others
//Use probabalistic decisions to weight each class
//Perhaps if something has roughly equal probabilities
//In multiple classes, it goes in each?
//SVM classification: evaluate each class's trained decision function on
//the document's characteristic-term vector; keep the top classTypes
//classes with positive responses, sorted by descending probability.
void svmClassify(DocItem& dcmt, Globals& global)
{
	sample_type sample;
	
	dcmt.classification[0] = "-1";
	
	//Build the sparse sample vector from the document's characteristic terms.
	characterizeDocument(dcmt, global);
	std::map<std::string,TermShort>::iterator dItr = dcmt.charTerms.begin();
	while (dItr != dcmt.charTerms.end())
	{
		sample[dItr->first] = termToDouble(dItr->second) / dcmt.termCount;
		dItr++;
	}
	
	std::pair<std::string,double> classes[classTypes];
	classes[0].first = "-1";
	classes[0].second = -1;
	//unsigned (fix): 't' was signed int, causing signed/unsigned
	//comparisons against the unsigned loop indices below.
	unsigned int t = 0;
	
	std::map<std::string,ClassStat>::iterator itr = global.docClasses.begin();
	while (itr != global.docClasses.end())
	{
		//Classify according to probabilistic decision function
		//Keep track of probabilities, pick highest classTypes
		//as long as they are above a certain threshhold.
		double p = itr->second.svmTrainer(sample);
		if (p > 0)
		{
			std::pair<std::string,double> mp;
			mp.first = itr->first;
			mp.second = p;
			
			//Insertion into the descending list; t never exceeds
			//classTypes, so classes[i] stays in bounds.
			unsigned int i = 0;
			while (i < t) //Keeps the classes sorted by max probability
			{
				if (p > classes[i].second)
				{
					std::pair<std::string,double> temp;
					temp.first = classes[i].first;
					temp.second = classes[i].second;
					classes[i].first = mp.first;
					classes[i].second = mp.second;
					mp.first = temp.first;
					mp.second = temp.second;
				}
				i++;
			}
			if (i < classTypes)
			{
				classes[i].first = mp.first;
				classes[i].second = mp.second;
				t++;
			}
		}
		
		itr++;
	}
	
	if (classes[0].second <= 0)
	{
		//No positive response from any class:
		//just output -1 as the class.
		t++;
	}
	
	//Assign classes to document
	for (unsigned int i = 0; i < t; i++)
	{
		dcmt.classification[i] = classes[i].first;
		std::cout<< "Assigned class [" << i << "] : ";
		std::cout<< classes[i].first << " : " << classes[i].second;
		std::cout<< " : " << dcmt.realClass[0] << "\n";
	}
}

//Doesn't actually cluster anything ;>_>
//Instead, this handles the training of the svm
//decision functions. Unfortunately, it basically
//has to start from scratch each time, but that's
//what the other clustering functions do too.
//Doesn't actually cluster anything: this (re)trains each class's SVM
//decision function, using the class's own characteristic-term vector as
//the positive example and every other class's vector as negatives.
void svmCluster(Globals& global)
{
	std::cout<<"SVM Clustering\n";
	for (std::map<std::string,ClassStat>::iterator clsIt = global.docClasses.begin();
	     clsIt != global.docClasses.end(); ++clsIt)
	{
		ClassStat& stat = clsIt->second;
		//Reset the trainer and re-apply the configured parameters.
		stat.svmTrainer.clear();
		stat.svmTrainer.set_lambda(lambda);
		stat.svmTrainer.set_tolerance(tol);
		stat.svmTrainer.set_max_num_sv(maxVect);
		
		//Positive example: this class's own term vector.
		sample_type posSample;
		for (std::map<std::string,TermShort>::iterator termIt = stat.charTerms.begin();
		     termIt != stat.charTerms.end(); ++termIt)
			posSample[termIt->first] = termToDouble(termIt->second) / stat.termCount;
		
		for (std::map<std::string,ClassStat>::iterator otherIt = global.docClasses.begin();
		     otherIt != global.docClasses.end(); ++otherIt)
		{
			if (clsIt->first == otherIt->first)
			{
				stat.svmTrainer.train(posSample,1);
			}
			else
			{
				//Negative example: the other class's term vector.
				sample_type negSample;
				for (std::map<std::string,TermShort>::iterator termIt = otherIt->second.charTerms.begin();
				     termIt != otherIt->second.charTerms.end(); ++termIt)
					negSample[termIt->first] = termToDouble(termIt->second) / otherIt->second.termCount;
				stat.svmTrainer.train(negSample,-1);
			}
		}
	}
}

//Extended Jaccard Coefficients
//Basically, Tanimoto Coefficients
//Similarity metric for attribute vectors
//Extended Jaccard (Tanimoto) classification over sparse term vectors:
//similarity = dot / (|c|^2 + |d|^2 - dot). Terms not shared contribute 0
//to the dot product, so disjoint vectors score 0.
void ejcClassify(DocItem& dcmt, Globals& global)
{
	//Precompute |d|^2 for the document's characteristic-term vector.
	double dm = 0;
	
	characterizeDocument(dcmt, global);
	std::map<std::string,TermShort>::iterator dItr = dcmt.charTerms.begin();
	while (dItr != dcmt.charTerms.end())
	{
		double d = termToDouble(dItr->second) / dcmt.termCount;
		dm+= d*d;
		dItr++;
	}
	
	std::pair<std::string,double> results[classTypes];
	unsigned int t = 0;
	
	std::map<std::string,ClassStat>::iterator itr = global.docClasses.begin();
	while (itr != global.docClasses.end())
	{
		double cm = 0;
		double dot = 0;
		std::pair<std::string,double> result;
		
		std::map<std::string,TermShort>::iterator tItr = itr->second.charTerms.begin();
		while (tItr != itr->second.charTerms.end())
		{
			double d1 = termToDouble(tItr->second) / itr->second.termCount;
			dItr = dcmt.charTerms.find(tItr->first);
			double d2 = 0;	//terms absent from the document contribute 0
			if (dItr != dcmt.charTerms.end())
				d2 = termToDouble(dItr->second) / dcmt.termCount;
			dot+= d1*d2;
			cm+= d1*d1;
			
			tItr++;
		}
		
		result.first = itr->first;
		//Guard (bug fix): when both vectors are empty the denominator is
		//0 and the original produced NaN; score 0 instead, which the
		//"> 0" filter below rejects just as NaN was.
		double denom = cm + dm - dot;
		result.second = (denom != 0) ? dot / denom : 0;
		
		//Insertion into the descending fixed-size result list.
		unsigned int i = 0;
		while (i < t)
		{
			if (result.second > results[i].second)
			{
				std::string temp1 = results[i].first;
				double temp2 = results[i].second;
				results[i].first = result.first;
				results[i].second = result.second;
				result.first = temp1;
				result.second = temp2;
			}
			i++;
		}
		if (i < classTypes && result.second > 0)
		{
			results[i].first = result.first;
			results[i].second = result.second;
			t++;
		}
		
		itr++;
	}
	
	//Assign classes to document
	for (unsigned int i = 0; i < t; i++)
	{
		dcmt.classification[i] = results[i].first;
		std::cout<< "Assigned class [" << i << "] : ";
		std::cout<< results[i].first << " : " << results[i].second;
		std::cout<< " : " << dcmt.realClass[0] << "\n";
	}
	if (t == 0)
	{
		//No class scored above zero: report -1.
		dcmt.classification[0] = "-1";
		std::cout<< "Assigned class [" << 0 << "] : ";
		std::cout<< dcmt.classification[0] << " : ";
		std::cout<< dcmt.realClass[0] << "\n";
	}
}

//Dump each class name and its assigned cluster id to stdout.
void clusterPrint(Globals& global)
{
	std::map<std::string, ClassStat>::iterator cIt;
	for (cIt = global.docClasses.begin(); cIt != global.docClasses.end(); ++cIt)
		std::cout<< cIt->first << " : " << cIt->second.cluster << "\n";
}

//Dump every class and its characteristic terms (count, dCount, idf).
void classPrint(Globals& global)
{
	std::map<std::string, ClassStat>::iterator cIt;
	for (cIt = global.docClasses.begin(); cIt != global.docClasses.end(); ++cIt)
	{
		std::cout << "Class " << cIt->first << "\n";
		std::map<std::string, TermShort>::iterator tIt;
		for (tIt = cIt->second.charTerms.begin(); tIt != cIt->second.charTerms.end(); ++tIt)
		{
			std::cout << tIt->first << " " << tIt->second.count << " ";
			std::cout << tIt->second.dCount << " " << tIt->second.idf << "\n";
		}
		std::cout << "\n";
	}
}

//Remove global terms that are too common (idf below minIDF) or too rare
//(total count below the support threshold).
void cull(Globals& global)
{
	std::cout << "Culling\n";
	//Loop-invariant support threshold, hoisted out of the loop.
	double mtf = minKeep * global.docCount / (double)supportScale;
	std::map<std::string, TermStat>::iterator itr = global.gTerms.begin();
	while (itr != global.gTerms.end())
	{
		//Reference (fix): the original copied the whole TermStat — and its
		//cCounts map — on every iteration just to read two fields.
		const TermStat& term = itr->second;
		double tf = term.count; // (double)term.dCount;
		if (term.idf < minIDF || tf < mtf)
			global.gTerms.erase(itr++);	//post-increment: safe map erase idiom
		else
			itr++;
	}
}
